{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multiple choice"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "多肢選択タスクは質問応答に似ていますが、いくつかの候補の回答がコンテキストとともに提供され、正しい回答を選択するようにモデルがトレーニングされる点が異なります。\n",
    "\n",
    "このガイドでは、次の方法を説明します。\n",
    "\n",
    "1. [SWAG](https://huggingface.co/datasets/swag) データセットの「通常」構成で [BERT](https://huggingface.co/google-bert/bert-base-uncased) を微調整して、最適なデータセットを選択します複数の選択肢と何らかのコンテキストを考慮して回答します。\n",
    "2. 微調整したモデルを推論に使用します。\n",
    "\n",
    "始める前に、必要なライブラリがすべてインストールされていることを確認してください。\n",
    "\n",
    "```bash\n",
    "pip install transformers datasets evaluate\n",
    "```\n",
    "\n",
    "モデルをアップロードしてコミュニティと共有できるように、Hugging Face アカウントにログインすることをお勧めします。プロンプトが表示されたら、トークンを入力してログインします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load SWAG dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "まず、🤗 データセット ライブラリから SWAG データセットの「通常」構成をロードします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "swag = load_dataset(\"swag\", \"regular\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、例を見てみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'ending0': 'passes by walking down the street playing their instruments.',\n",
       " 'ending1': 'has heard approaching them.',\n",
       " 'ending2': \"arrives and they're outside dancing and asleep.\",\n",
       " 'ending3': 'turns the lead singer watches the performance.',\n",
       " 'fold-ind': '3416',\n",
       " 'gold-source': 'gold',\n",
       " 'label': 0,\n",
       " 'sent1': 'Members of the procession walk down the street holding small horn brass instruments.',\n",
       " 'sent2': 'A drum line',\n",
       " 'startphrase': 'Members of the procession walk down the street holding small horn brass instruments. A drum line',\n",
       " 'video-id': 'anetv_jkn6uvmqwh4'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "swag[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ここにはたくさんのフィールドがあるように見えますが、実際は非常に簡単です。\n",
    "\n",
    "- `sent1` と `sent2`: これらのフィールドは文の始まりを示し、この 2 つを組み合わせると `startphrase` フィールドが得られます。\n",
    "- `ending`: 文の終わり方として考えられる終わり方を示唆しますが、正しいのは 1 つだけです。\n",
    "- `label`: 正しい文の終わりを識別します。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次のステップでは、BERT トークナイザーをロードして、文の始まりと 4 つの可能な終わりを処理します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google-bert/bert-base-uncased\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "作成する前処理関数は次のことを行う必要があります。\n",
    "\n",
    "1. `sent1` フィールドのコピーを 4 つ作成し、それぞれを `sent2` と組み合わせて文の始まりを再現します。\n",
    "2. `sent2` を 4 つの可能な文末尾のそれぞれと組み合わせます。\n",
    "3. これら 2 つのリストをトークン化できるようにフラット化し、その後、各例に対応する `input_ids`、`attention_mask`、および `labels` フィールドが含まれるように非フラット化します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ending_names = [\"ending0\", \"ending1\", \"ending2\", \"ending3\"]\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    first_sentences = [[context] * 4 for context in examples[\"sent1\"]]\n",
    "    question_headers = examples[\"sent2\"]\n",
    "    second_sentences = [\n",
    "        [f\"{header} {examples[end][i]}\" for end in ending_names] for i, header in enumerate(question_headers)\n",
    "    ]\n",
    "\n",
    "    first_sentences = sum(first_sentences, [])\n",
    "    second_sentences = sum(second_sentences, [])\n",
    "\n",
    "    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)\n",
    "    return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "データセット全体に前処理関数を適用するには、🤗 Datasets `map` メソッドを使用します。 `batched=True` を設定してデータセットの複数の要素を一度に処理することで、`map` 関数を高速化できます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_swag = swag.map(preprocess_function, batched=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`DataCollatorForMultipleChoice` は、すべてのモデル入力を平坦化し、パディングを適用して、結果を非平坦化します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForMultipleChoice\n",
    "collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニング中にメトリクスを含めると、多くの場合、モデルのパフォーマンスを評価するのに役立ちます。 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) ライブラリを使用して、評価メソッドをすばやくロードできます。このタスクでは、[accuracy](https://huggingface.co/spaces/evaluate-metric/accuracy) メトリクスを読み込みます (🤗 Evaluate [クイック ツアー](https://huggingface.co/docs/evaluate/a_quick_tour) を参照してください) ) メトリクスの読み込みと計算方法の詳細については、次を参照してください)。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "accuracy = evaluate.load(\"accuracy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、予測とラベルを `compute` に渡して精度を計算する関数を作成します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    predictions = np.argmax(predictions, axis=1)\n",
    "    return accuracy.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "これで`compute_metrics`関数の準備が整いました。トレーニングをセットアップするときにこの関数に戻ります。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "[Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) を使用したモデルの微調整に慣れていない場合は、[ここ](https://huggingface.co/docs/transformers/main/ja/tasks/../training#train-with-pytorch-trainer) の基本的なチュートリアルをご覧ください。\n",
    "\n",
    "</Tip>\n",
    "\n",
    "これでモデルのトレーニングを開始する準備が整いました。 [AutoModelForMultipleChoice](https://huggingface.co/docs/transformers/main/ja/model_doc/auto#transformers.AutoModelForMultipleChoice) を使用して BERT をロードします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer\n",
    "\n",
    "model = AutoModelForMultipleChoice.from_pretrained(\"google-bert/bert-base-uncased\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "この時点で残っている手順は次の 3 つだけです。\n",
    "\n",
    "1. [TrainingArguments](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.TrainingArguments) でトレーニング ハイパーパラメータを定義します。唯一の必須パラメータは、モデルの保存場所を指定する `output_dir` です。 `push_to_hub=True`を設定して、このモデルをハブにプッシュします (モデルをアップロードするには、Hugging Face にサインインする必要があります)。各エポックの終了時に、[Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) は精度を評価し、トレーニング チェックポイントを保存します。\n",
    "2. トレーニング引数を、モデル、データセット、トークナイザー、データ照合器、および `compute_metrics` 関数とともに [Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) に渡します。\n",
    "3. [train()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.train) を呼び出してモデルを微調整します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"my_awesome_swag_model\",\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    load_best_model_at_end=True,\n",
    "    learning_rate=5e-5,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=3,\n",
    "    weight_decay=0.01,\n",
    "    push_to_hub=True,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_swag[\"train\"],\n",
    "    eval_dataset=tokenized_swag[\"validation\"],\n",
    "    processing_class=tokenizer,\n",
    "    data_collator=collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニングが完了したら、 [push_to_hub()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.push_to_hub) メソッドを使用してモデルをハブに共有し、誰もがモデルを使用できますように。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "複数選択用にモデルを微調整する方法の詳細な例については、対応するセクションを参照してください。\n",
    "[PyTorch ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multiple_choice.ipynb)\n",
    "または [TensorFlow ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multiple_choice-tf.ipynb)。\n",
    "\n",
    "</Tip>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "モデルを微調整したので、それを推論に使用できるようになりました。\n",
    "\n",
    "いくつかのテキストと 2 つの回答候補を考えてください。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"France has a bread law, Le Décret Pain, with strict rules on what is allowed in a traditional baguette.\"\n",
    "candidate1 = \"The law does not apply to croissants and brioche.\"\n",
    "candidate2 = \"The law applies to baguettes.\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "各プロンプトと回答候補のペアをトークン化し、PyTorch テンソルを返します。いくつかの`lables`も作成する必要があります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"my_awesome_swag_model\")\n",
    "inputs = tokenizer([[prompt, candidate1], [prompt, candidate2]], return_tensors=\"pt\", padding=True)\n",
    "labels = torch.tensor(0).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "入力とラベルをモデルに渡し、`logits`を返します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForMultipleChoice\n",
    "\n",
    "model = AutoModelForMultipleChoice.from_pretrained(\"my_awesome_swag_model\")\n",
    "outputs = model(**{k: v.unsqueeze(0) for k, v in inputs.items()}, labels=labels)\n",
    "logits = outputs.logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最も高い確率でクラスを取得します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted_class = logits.argmax().item()\n",
    "predicted_class"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
